{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "authorship_tag": "ABX9TyNKlk6MKeCAFiaFkcs9pvkB",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/blog/unittests_in_beam.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7DSE6TgWy7PP"
      },
      "outputs": [],
      "source": [
        "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
        "\n",
        "# Licensed to the Apache Software Foundation (ASF) under one\n",
        "# or more contributor license agreements. See the NOTICE file\n",
        "# distributed with this work for additional information\n",
        "# regarding copyright ownership. The ASF licenses this file\n",
        "# to you under the Apache License, Version 2.0 (the\n",
        "# \"License\"); you may not use this file except in compliance\n",
        "# with the License. You may obtain a copy of the License at\n",
        "#\n",
        "#   http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing,\n",
        "# software distributed under the License is distributed on an\n",
        "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
        "# KIND, either express or implied. See the License for the\n",
        "# specific language governing permissions and limitations\n",
        "# under the License"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Install the Apache Beam library\n",
        "!pip install apache_beam[interactive,gcp] --quiet"
      ],
      "metadata": {
        "id": "5W2nuV7uzlPg"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "***Notebook Description***\n",
        "\n",
        "This colab notebook is referenced in the Unit Testing in Beam blog post, an opinionated guide on authoring unit tests for Apache Beam pipelines. Note that this code is used solely for conveying the best practices of unit testing, and thus is not intended to be used out of the box in user pipelines. The code references files native in colab, and thus executing the cells in order will provide the intended output.\n",
        "\n",
        "\n",
        "***Note: Running the cells does not invoke the tests written in this notebook. These could be run in a local IDE using [pytest](https://docs.pytest.org/en/stable/).***"
      ],
      "metadata": {
        "id": "_KqwMLN9kRXf"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "**Example 1**\n",
        "\n",
        "The following example shows how to use the `Map` construct to calculate median house value per bedroom.\n"
      ],
      "metadata": {
        "id": "IVjBkewt1sLA"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import apache_beam as beam\n",
        "from apache_beam.io import ReadFromText, WriteToText\n",
        "\n",
        "# The following code computes the median house value per bedroom\n",
        "def median_house_value_per_bedroom(element):\n",
        "  # median_house_value is at index 8 and total_bedrooms is at index 4\n",
        "  element = element.strip().split(',')\n",
        "  return float(element[8])/float(element[4])\n",
        "\n",
        "\n",
        "with beam.Pipeline() as p1:\n",
        "    result = (\n",
        "        p1\n",
        "        | ReadFromText(\"/content/sample_data/california_housing_test.csv\",skip_header_lines=1)\n",
        "        | beam.Map(median_house_value_per_bedroom)\n",
        "        | WriteToText(\"/content/example2\")\n",
        "    )"
      ],
      "metadata": {
        "id": "8Y5bQXIyfhra"
      },
      "execution_count": null,
      "outputs": []
    },
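    {
      "cell_type": "markdown",
      "source": [
        "Because `median_house_value_per_bedroom` is a plain Python function, its parsing logic can be checked without building a pipeline. The following is a minimal sketch of a quick inline check rather than a pytest test. The input line is a made-up record (not taken from the sample data) with `2` at index 4 and `10` at index 8, and the cell assumes the Example 1 definition above is the one currently in scope."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "# Made-up CSV line for illustration: total_bedrooms=2 (index 4), median_house_value=10 (index 8).\n",
        "sample_line = '0,0,0,0,2,0,0,0,10'\n",
        "\n",
        "# 10 / 2 == 5.0, so the function should return exactly 5.0.\n",
        "assert median_house_value_per_bedroom(sample_line) == 5.0"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },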
    {
      "cell_type": "markdown",
      "source": [
        "**Example 2**\n",
        "\n",
        "The following code is an extension of example 1, but with more complex pipeline logic. The `median_house_value_per_bedroom` function is now more complex, and involves writing to various keys."
      ],
      "metadata": {
        "id": "Mh3nZZ1_12sX"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import random\n",
        "# The following code computes the median house value per bedroom.\n",
        "counter=-1 #define a counter globally\n",
        "\n",
        "\n",
        "def median_house_value_per_bedroom(element):\n",
        "  # median_house_value is at index 8 and total_bedrooms is at index 4, all as part of they key \"1\".\n",
        "  global counter\n",
        "  element = element.strip().split(',')\n",
        "  # Create multiple keys based on different fields\n",
        "  keys = [1,2,3]\n",
        "  counter+=1\n",
        "  value = float(element[8]) / float(element[4])  # Calculate median house value per bedroom\n",
        "  return keys[counter%3],value\n",
        "\n",
        "def multiply_by_factor(element):\n",
        "  key,value=element\n",
        "  return (key,value*10)\n",
        "\n",
        "\n",
        "with beam.Pipeline() as p2:\n",
        "    result = (\n",
        "        p2\n",
        "        | ReadFromText(\"/content/sample_data/california_housing_test.csv\",skip_header_lines=1)\n",
        "        | beam.Map(median_house_value_per_bedroom)\n",
        "        | beam.Map(multiply_by_factor)\n",
        "        | beam.CombinePerKey(sum)\n",
        "        | WriteToText(\"/content/example3\")\n",
        "    )\n"
      ],
      "metadata": {
        "id": "hmO1Chl51vPG"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "The following is code that is abstracting the two `Map` constructs, and the `CombinePerKey` construct. Note that this best practice makes our complex pipeline easier to test."
      ],
      "metadata": {
        "id": "cIuk2vT9sqOZ"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def transform_data_set(pcoll):\n",
        "  return (pcoll\n",
        "          | beam.Map(median_house_value_per_bedroom)\n",
        "          | beam.Map(multiply_by_factor)\n",
        "          | beam.CombinePerKey(sum))\n",
        "\n",
        "# Define a new class that inherits from beam.PTransform.\n",
        "class MapAndCombineTransform(beam.PTransform):\n",
        "  def expand(self, pcoll):\n",
        "    return transform_data_set(pcoll)\n",
        "\n",
        "with beam.Pipeline() as p2:\n",
        "   result = (\n",
        "       p2\n",
        "       | ReadFromText(\"/content/sample_data/california_housing_test.csv\",skip_header_lines=1)\n",
        "       | MapAndCombineTransform() # Use the new PTransform class\n",
        "       | WriteToText(\"/content/example3\")\n",
        "   )"
      ],
      "metadata": {
        "id": "XchihrbEqtlR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "**Unit Test for Pipeline 2**\n",
        "\n",
        "We've populated some sample records here, as well as set what we're expecting our expected value to be."
      ],
      "metadata": {
        "id": "uoNJLQl_15gj"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import unittest\n",
        "from apache_beam.testing.test_pipeline import TestPipeline\n",
        "from apache_beam.testing.util import assert_that, equal_to\n",
        "\n",
        "\n",
        "class TestBeam(unittest.TestCase):\n",
        "\n",
        "# This test corresponds to example 3, and is written to confirm the pipeline works as intended.\n",
        "  def test_transform_data_set(self):\n",
        "    expected=[(1, 10570.185786231425), (2, 13.375337533753376), (3, 13.315649867374006)]\n",
        "    input_elements = [\n",
        "      '-122.050000,37.370000,27.000000,3885.000000,661.000000,1537.000000,606.000000,6.608500,344700.000000',\n",
        "      '121.05,99.99,23.30,39.5,55.55,41.01,10,34,74.30,91.91',\n",
        "      '122.05,100.99,24.30,40.5,56.55,42.01,11,35,75.30,92.91',\n",
        "      '-120.05,39.37,29.00,4085.00,681.00,1557.00,626.00,6.8085,364700.00'\n",
        "    ]\n",
        "    with beam.Pipeline() as p2:\n",
        "      result = (\n",
        "                p2\n",
        "                | beam.Create(input_elements)\n",
        "                | beam.Map(MapAndCombineTransform())\n",
        "        )\n",
        "      assert_that(result,equal_to(expected))"
      ],
      "metadata": {
        "id": "BU9Eil-TrtpE"
      },
      "execution_count": null,
      "outputs": []
    },
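    {
      "cell_type": "markdown",
      "source": [
        "Smaller pieces of the pipeline can be tested the same way. The following is a minimal sketch, not part of the original blog post, that checks `multiply_by_factor` on its own. The function name `test_multiply_by_factor` and the input values are illustrative assumptions. Because this step does not depend on the global `counter`, its expected output does not depend on the order in which elements are processed."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import apache_beam as beam\n",
        "from apache_beam.testing.test_pipeline import TestPipeline\n",
        "from apache_beam.testing.util import assert_that, equal_to\n",
        "\n",
        "\n",
        "# Hypothetical pytest-style test; relies on multiply_by_factor from Example 2.\n",
        "def test_multiply_by_factor():\n",
        "  # Small hand-checked inputs: each value is multiplied by 10, keys are unchanged.\n",
        "  input_elements = [(1, 1.5), (2, 2.0)]\n",
        "  expected = [(1, 15.0), (2, 20.0)]\n",
        "  with TestPipeline() as p:\n",
        "    result = p | beam.Create(input_elements) | beam.Map(multiply_by_factor)\n",
        "    # assert_that is checked when the pipeline runs at the end of the with block.\n",
        "    assert_that(result, equal_to(expected))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },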
    {
      "cell_type": "markdown",
      "source": [
        "**Example 3**\n",
        "\n",
        "This `DoFn` and the corresponding pipeline demonstrate a `DoFn` making an API call. An error occurs if the length of the API response (`returned_record`) is less than the length `10`."
      ],
      "metadata": {
        "id": "Z8__izORM3r8"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Fake client to simulate an external call\n",
        "\n",
        "import time\n",
        "class Client():\n",
        "   def get_data(self, api):\n",
        "      time.sleep(3)\n",
        "      return [0,1,2,3,4,5,6,7,8,9]\n",
        "\n",
        "MyApiCall = Client()"
      ],
      "metadata": {
        "id": "GGPF7cY3Ntyj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "**Note:** The following cell can take about 2 minutes to run"
      ],
      "metadata": {
        "id": "3tGnPucbzmEx"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# The following packages are used to run the example pipelines.\n",
        "from apache_beam.options.pipeline_options import PipelineOptions\n",
        "\n",
        "class MyDoFn(beam.DoFn):\n",
        "  def process(self,element):\n",
        "          returned_record = MyApiCall.get_data(\"http://my-api-call.com\")\n",
        "          if len(returned_record)!=10:\n",
        "            raise ValueError(\"Length of record does not match expected length\")\n",
        "          yield returned_record\n",
        "\n",
        "with beam.Pipeline() as p3:\n",
        "  result = (\n",
        "          p3\n",
        "          | ReadFromText(\"/content/sample_data/anscombe.json\")\n",
        "          | beam.ParDo(MyDoFn())\n",
        "          | WriteToText(\"/content/example1\")\n",
        "  )"
      ],
      "metadata": {
        "id": "Ktk9EVIFzGfP"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "**Mocking Example**\n",
        "\n",
        "To test the error message, mock an API response, as demonstrated in the following blocks of code. Use mocking to avoid making the actual API call in the test."
      ],
      "metadata": {
        "id": "58GVMyMa2PwE"
      }
    },
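    {
      "cell_type": "markdown",
      "source": [
        "One way to see a mock in action without running a pipeline is to call `MyDoFn.process` directly. The following is a minimal sketch that assumes the `MyDoFn` and `MyApiCall` definitions from the cells above; the helper name `check_error_on_short_response` is hypothetical. `patch.object` from the standard library's `unittest.mock` replaces `get_data` on the `MyApiCall` instance for the duration of the `with` block, so the simulated 3-second call is skipped. Calling the helper returns the message of the `ValueError` raised by `MyDoFn`."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "from unittest.mock import patch\n",
        "\n",
        "\n",
        "# Hypothetical helper, not part of the original post. It relies on MyDoFn and\n",
        "# MyApiCall from the cells above.\n",
        "def check_error_on_short_response():\n",
        "  # Temporarily replace get_data on the MyApiCall instance with a mock that\n",
        "  # returns a 2-item record, so no (simulated) API call is made.\n",
        "  with patch.object(MyApiCall, 'get_data', return_value=['field1', 'field2']):\n",
        "    try:\n",
        "      # process() is a generator, so list() is needed to actually run its body.\n",
        "      list(MyDoFn().process('any element'))\n",
        "    except ValueError as err:\n",
        "      return str(err)\n",
        "  return None"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },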
    {
      "cell_type": "code",
      "source": [
        "!pip install mock  # Install the 'mock' module."
      ],
      "metadata": {
        "id": "ESclJ_G-6JcW"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Import the mock package for mocking functionality.\n",
        "from unittest.mock import Mock,patch\n",
        "# from MyApiCall import get_data\n",
        "import mock\n",
        "\n",
        "\n",
        "# MyApiCall is a function that calls get_data to fetch some data by using an API call.\n",
        "@patch('MyApiCall.get_data')\n",
        "def test_error_message_wrong_length(self, mock_get_data):\n",
        " response = ['field1','field2']\n",
        " mock_get_data.return_value = Mock()\n",
        " mock_get_data.return_value.json.return_value=response\n",
        "\n",
        " input_elements = ['-122.050000,37.370000,27.000000,3885.000000,661.000000,1537.000000,606.000000,6.608500,344700.000000'] #input length 9\n",
        " with self.assertRaisesRegex(ValueError,\n",
        "                             \"Length of record does not match expected length'\"):\n",
        "     p3 = beam.Pipeline()\n",
        "     result = p3 | beam.create(input_elements) | beam.ParDo(MyDoFn())\n",
        "     result\n"
      ],
      "metadata": {
        "id": "IRuv8s8a2O8F"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}
